Input Wrapper #883
Conversation
Force-pushed from b4ad71e to 52b6264.
To add info on this, this is how I was able to define a UNet. (Note: in this case I was using a version of dfdx that had some other local changes, especially experimental ones.)

```rust
#[input_wrapper]
#[derive(Clone, Debug)]
pub struct Split<Forward, Skip> {
pub forward: Forward,
pub skip: Skip,
}
impl<Forward, Skip, const AXIS: isize> TryConcatTensorAlong<Axis<AXIS>> for Split<Forward, Skip>
where
(Forward, Skip): TryConcatTensorAlong<Axis<AXIS>>,
{
type Output = <(Forward, Skip) as TryConcatTensorAlong<Axis<AXIS>>>::Output;
fn try_concat_tensor_along(self, ax: Axis<AXIS>) -> Result<Self::Output, Error> {
let (forward, skip) = self.into();
(forward, skip).try_concat_tensor_along(ax)
}
}
/// From:
/// ```ignore
/// batch * CH_IN * height * width
/// ```
///
/// To:
/// ```ignore
/// batch * CH_OUT * height * width
/// ```
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(ConvBlock)]
pub struct ConvBlockConfig<const CH_IN: usize, const CH_OUT: usize> {
pub conv_1: Conv2DConstConfig<CH_IN, CH_OUT, 3, 1, 1>,
pub norm_1: BatchNorm2DConstConfig<CH_OUT>,
pub a_1: ReLU,
//
pub conv_2: Conv2DConstConfig<CH_OUT, CH_OUT, 3, 1, 1>,
pub norm_2: BatchNorm2DConstConfig<CH_OUT>,
pub a_2: ReLU,
}
/// From:
/// ```ignore
/// batch * CH_IN * height * width
/// ```
///
/// To:
/// ```ignore
/// Split {
/// forward: batch * CH_OUT * height/2 * width/2,
/// skip: batch * CH_OUT * height * width,
/// }
/// ```
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(DownBlock)]
pub struct DownBlockConfig<const CH_IN: usize, const CH_OUT: usize> {
pub conv: ConvBlockConfig<CH_IN, CH_OUT>,
//
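// duplicate the ConvBlock output into two identical copies, wrap them as
// Split { forward, skip }, and max-pool only the `forward` field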
pub split: SplitInto<(Id, Id)>,
pub wrapper: split::FromTuple,
pub pool: On<split::forward, MaxPool2DConst<2, 2, 0>>,
}
/// From:
/// ```ignore
/// Split {
/// forward: batch * CH_INF * height/2 * width/2,
/// skip: batch * CH_INS * height * width,
/// }
/// ```
///
/// To:
/// ```ignore
/// batch * CH_OUT * height * width
/// ```
///
/// Notes:
/// - `CH_INF` refers to the #channels from [`Split::forward`].
/// - `CH_INS` refers to the #channels from [`Split::skip`], but this parameter is not directly passed to this structure.
/// - `CH_CONCAT` is supposed to be `CH_OUT + CH_INS`.
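/// - e.g. in the model below, `up_block_4` is `UpBlockConfig<512, 256, 512>`: its
///   transposed conv outputs `CH_OUT = 256` channels and the matching skip carries
///   `CH_INS = 256` channels, so `CH_CONCAT = 256 + 256 = 512`.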
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(UpBlock)]
pub struct UpBlockConfig<const CH_INF: usize, const CH_OUT: usize, const CH_CONCAT: usize> {
// for keras' padding='same', the PADDING value must be set to:
// ((kernel-1) * dilation + 1) // 2 = ((3-1) * 1 + 1) // 2 = 3 // 2 = 1
pub conv_trans:
On<split::forward, ConvTrans2DConstConfig<CH_INF, CH_OUT, 3, 2, 1, 1, 1, 1>>,
pub bias: On<split::forward, Bias2DConstConfig<CH_OUT>>,
// concat "skip" and "forward" along channels
pub tuple: split::IntoTuple,
pub concat: ops::ConcatTensorAlong<Axis<1>>,
pub conv: ConvBlockConfig<CH_CONCAT, CH_OUT>,
}
/// Just applies `M`.
type Onc0<M> = M;
/// Accesses `F` and then applies `M`.
type Onc1<F, M> = On<F, M>;
/// Accesses `F` 2 consecutive times and then applies `M`.
type Onc2<F, M> = On<F, On<F, M>>;
/// Accesses `F` 3 consecutive times and then applies `M`.
type Onc3<F, M> = On<F, Onc2<F, M>>;
/// Accesses `F` 4 consecutive times and then applies `M`.
type Onc4<F, M> = On<F, Onc3<F, M>>;
/// Accesses `F` 5 consecutive times and then applies `M`.
type Onc5<F, M> = On<F, Onc4<F, M>>;
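// Illustrative expansion (not generated code): `Onc3<split::forward, M>` is
// `On<split::forward, On<split::forward, On<split::forward, M>>>`, i.e. `M` is
// applied three `Split::forward` accesses deep.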
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(Model)]
pub struct ModelConfig {
// encoder
pub down_block_0: Onc0<DownBlockConfig<3, 32>>,
pub down_block_1: Onc1<split::forward, DownBlockConfig<32, 64>>,
pub down_block_2: Onc2<split::forward, DownBlockConfig<64, 128>>,
pub down_block_3: Onc3<split::forward, DownBlockConfig<128, 256>>,
// bottleneck
// note: this increases the channels but does not reduce the height nor the width
pub conv_bottle: Onc4<split::forward, ConvBlockConfig<256, 512>>,
// decoder
pub up_block_4: Onc3<split::forward, UpBlockConfig<512, 256, 512>>,
pub up_block_3: Onc2<split::forward, UpBlockConfig<256, 128, 256>>,
pub up_block_2: Onc1<split::forward, UpBlockConfig<128, 64, 128>>,
pub up_block_1: Onc0<UpBlockConfig<64, 32, 64>>,
// y (class) channel conversion
pub conv_2: Conv2DConstConfig<32, 13, 1, 1, 0>,
pub bias_2: Bias2DConstConfig<13>,
}
```

And this is how I defined and trained a simple RNN (based on this exercise).

```rust
pub mod model {
use super::*;
#[input_wrapper]
#[derive(Clone, Debug)]
pub struct Input<A, X> {
pub a_prev: A,
pub x: X,
}
impl<A, X, const AXIS: isize> TryConcatTensorAlong<Axis<AXIS>> for Input<A, X>
where
(A, X): TryConcatTensorAlong<Axis<AXIS>>,
{
type Output = <(A, X) as TryConcatTensorAlong<Axis<AXIS>>>::Output;
fn try_concat_tensor_along(self, ax: Axis<AXIS>) -> Result<Self::Output, Error> {
(self.a_prev, self.x).try_concat_tensor_along(ax)
}
}
#[input_wrapper]
#[derive(Clone, Debug)]
pub struct Output<A, Y> {
pub a: A,
pub y: Y,
}
impl<AS: Shape, YS: Shape, E: Dtype, D: Device<E>>
Output<Tensor<AS, E, D, OwnedTape<E, D>>, Tensor<YS, E, D, OwnedTape<E, D>>>
{
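/// Merges `a`'s gradient tape into `y`'s tape, so a single `backward()` from a
/// loss computed on `y` covers every operation recorded so far, while `a` is
/// returned with a fresh (leaky) tape so it can be fed into the next timestep.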
pub fn merge_tapes_on_y(self) -> Self {
let (a, at) = self.a.split_tape();
let (y, yt) = self.y.split_tape();
Self {
a: a.leaky_traced(),
y: y.put_tape(at.merge(yt)),
}
}
}
//
/// Input:
/// ```ignore
/// Input {
/// a_prev: A,
/// x: X,
/// }
/// ```
///
/// Output:
/// ```ignore
/// Output {
/// a: A,
/// y: Y,
/// }
/// ```
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(Cell)]
pub struct CellConstConfig<
const NA: usize,
const NX: usize,
const NY: usize,
const CONCAT_AXIS: isize,
const NAPNX: usize = { NA + NX },
> {
// doing concat(a_prev, x) dot concat(wa^t, wb^t) + b is the same as
// doing a_prev dot wa^t + x dot wb^t + b
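// (block-matrix view: [a_prev | x] * [wa^t ; wb^t] = a_prev * wa^t + x * wb^t)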
//
// pub amul: On<input::a_prev, MatMulConstConfig<NA, NA>>,
// pub xmul: On<input::x, MatMulConstConfig<NX, NA>>,
// pub ax_tuple: input::IntoTuple,
// pub ax_add: ops::Add,
// pub bias: Bias1DConstConfig<NA>,
//
pub concat_input: ops::ConcatTensorAlong<Axis<CONCAT_AXIS>>,
pub ax_linear: LinearConstConfig<NAPNX, NA>,
//
pub g1: Tanh,
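// duplicate the activation into two copies, wrap them as Output { a, y },
// and apply the output head (a linear layer into NY logits) only to `y`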
pub ay_tuple: SplitInto<(Id, Id)>,
pub ay: output::FromTuple,
pub ylinear: On<output::y, LinearConstConfig<NA, NY>>,
}
}
#[test]
fn test_rnn() -> anyhow::Result<()> {
let dev = Cuda::try_build(0, 0)?;
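// note: `Device_` used below is a device type alias from the original code
// (not shown in this snippet), presumably pointing at the same device as `dev`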
const T_: usize = 2;
const BATCH: usize = 3;
// for unbatched (1D) tensors, the concat axis is 0
// for batched (2D) tensors, the concat axis is 1
const CONCAT_AXIS: isize = 1;
const NA: usize = 2;
const NX: usize = 3;
const NY: usize = 3;
type XT<T = NoneTape> = Tensor<Rank2<BATCH, NX>, f32, Device_, T>;
type AT<T = NoneTape> = Tensor<Rank2<BATCH, NA>, f32, Device_, T>;
type YT<T = NoneTape> = Tensor<Rank2<BATCH, NY>, f32, Device_, T>;
let mut model =
dev.build_module::<f32>(model::CellConstConfig::<NA, NX, NY, CONCAT_AXIS>::default());
let mut grads = model.alloc_grads();
let mut opt = dfdx::prelude::optim::Adam::new(
&model,
AdamConfig {
lr: 1e-4,
..Default::default() // weight_decay: Some(dfdx::nn::optim::WeightDecay::L2(0.001)),
},
);
const EPOCHS: usize = 2;
for e in 0..EPOCHS {
let a_prev: AT = dev.zeros();
let mut x: XT = dev.zeros();
let mut a_prev_t: AT<_> = a_prev.leaky_traced();
let mut batch_loss = 0f32;
for _t in 0..T_ {
let y_t: YT = dev.sample_uniform();
let x_t: XT<OwnedTape<f32, Device_>> = x.traced(grads);
let input = model::Input {
a_prev: a_prev_t,
x: x_t,
};
let prediction = model.forward_mut(input);
let prediction = prediction.merge_tapes_on_y();
let loss_t =
dfdx::losses::cross_entropy_with_logits_loss(prediction.y, y_t.clone());
batch_loss += loss_t.array();
// Note:
// Running backprop and model update for each timestep t.
// A different approach would be to run backprop at the last timestep and update once.
// Or yet do something in between.
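// (a sketch of the latter: accumulate the per-timestep losses into a single
// traced scalar and call `backward()` / `opt.update(...)` once after the
// inner loop ends)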
grads = loss_t.backward();
opt.update(&mut model, &grads).unwrap();
x = y_t;
a_prev_t = prediction.a;
}
println!("epoch: {}; loss: {}", e, batch_loss);
// grads.drop_non_leafs();
model.zero_grads(&mut grads);
}
Ok(())
}
```

- Add the heck dep to convert from CamelCase into snake_case.
- Add layers:
  - `Id`, which just forwards the input.
  - `On`, which applies some Module to an input wrapper field. Contains a test demonstrating its usage.
  - `Add`, which calls `try_add` for the inputs.
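
As an illustration (not part of the PR), here is a minimal sketch of how `Id`, `On`, `ops::Add`, and an `#[input_wrapper]` type could compose into a residual-style block, reusing only constructs that already appear in the UNet/RNN code above; the `Branches` and `ResBlockConfig` names are hypothetical:

```rust
// Hypothetical sketch; assumes the same dfdx prelude and derives used above.
#[input_wrapper]
#[derive(Clone, Debug)]
pub struct Branches<Main, Residual> {
    pub main: Main,
    pub residual: Residual,
}

/// A residual-style block: run a conv path on `main`, keep `residual` untouched,
/// then add the two tensors element-wise via `ops::Add`.
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(ResBlock)]
pub struct ResBlockConfig<const CH: usize> {
    // duplicate the input into two identical copies
    pub split: SplitInto<(Id, Id)>,
    // wrap the (main, residual) tuple into Branches { main, residual }
    pub wrap: branches::FromTuple,
    // apply modules only to the `main` field
    pub conv: On<branches::main, Conv2DConstConfig<CH, CH, 3, 1, 1>>,
    pub bias: On<branches::main, Bias2DConstConfig<CH>>,
    pub act: On<branches::main, ReLU>,
    // unwrap back into a tuple and add the two branches
    pub tuple: branches::IntoTuple,
    pub add: ops::Add,
}
```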
Force-pushed from 52b6264 to 82c314b.
I'll prioritize moving this experiment to a separate crate, but feel free to ping me in case anyone has a question or suggestion.
This is a draft, closes #878.
Note: If this design ends up being useful, it could be implemented as a separate library (there are only code additions and they don't conflict with anything), but for feedback purposes it's better and more straightforward to currently keep this as a draft PR.
This adds:
- `#[input_wrapper]`.
- `Id`, which just forwards the input.
- `On`, which applies some Module to an input wrapper field.
- `Add`, which calls `try_add` for the inputs.

This is how it gets used:
dfdx/dfdx/src/nn/layers/on.rs, lines 41 to 64 in 52b6264
This is what gets generated from the above: